{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/MisaOgura/flashtorch/blob/master/examples/create_saliency_maps_for_torchvision_models_colab.ipynb\" target=\"_parent\">\n",
    "    <img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using `flashtorch.saliency.backprop` with `torchvision.models`\n",
    "---\n",
    "\n",
    "Demo of creating saliency maps for a wide range of models available from `torchvision.models`.\n",
    "\n",
    "It assumes that you are familiar with how to use `flashtorch.saliency.backprop` to create saliency maps. For step-by-step instructions on how to create saliency maps, head to [**this one**](https://colab.research.google.com/github/MisaOgura/flashtorch/blob/master/examples/visualize_saliency_with_backprop_colab.ipynb) instead.\n",
    "\n",
    "---\n",
    "\n",
    "❗This notebook is for those who are using this notebook in **Google Colab**.\n",
    "\n",
    "If you aren't on Google Colab already, please head to the Colab version of this notebook **[here](https://colab.research.google.com/github/MisaOgura/flashtorch/blob/master/examples/create_saliency_maps_for_torchvision_models_colab.ipynb)** to execute."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 0. Set up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install flashtorch\n",
    "\n",
    "!pip install flashtorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download example images\n",
    "\n",
    "!mkdir -p images\n",
    "\n",
    "!wget -nv \\\n",
    "    https://github.com/MisaOgura/flashtorch/raw/master/examples/images/peacock.jpg \\\n",
    "    -P /content/images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision.models as models\n",
    "\n",
    "from flashtorch.utils import apply_transforms, load_image\n",
    "from flashtorch.saliency import Backprop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Load an image "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image = load_image('/content/images/peacock.jpg')\n",
    "\n",
    "plt.imshow(image)\n",
    "plt.title('Original image')\n",
    "plt.axis('off');\n",
    "\n",
    "# Prepare inputs to be used later\n",
    "\n",
    "img = apply_transforms(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Use pre-trained models from `torchvision.models`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A wrapper to help out visualization\n",
    "\n",
    "def visualize_helper(model_module, tensor=img, k=84):\n",
    "    model = model_module(pretrained=True)\n",
    "    backprop = Backprop(model)\n",
    "    backprop.visualize(tensor, k, guided=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### AlexNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_helper(models.alexnet)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### VGG-19"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_helper(models.vgg19)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### VGG-19 with batch normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_helper(models.vgg19_bn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### ResNet-50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_helper(models.resnet50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### ResNet-152"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_helper(models.resnet152)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### SqueezeNet 1.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_helper(models.squeezenet1_1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Densenet-121"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_helper(models.densenet121)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Densenet-201"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_helper(models.densenet201)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Inception v3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image must be 299x299 for Inception models\n",
    "\n",
    "img_299 = apply_transforms(image, size=299)\n",
    "visualize_helper(models.inception_v3, img_299)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### GoogleNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_helper(models.googlenet)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### ShuffleNet V2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_helper(models.shufflenet_v2_x1_0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### MobileNet V2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_helper(models.mobilenet_v2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### ResNeXt-50-32x4d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_helper(models.resnext50_32x4d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Wide ResNet-50-2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_helper(models.wide_resnet50_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### MNASNet 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_helper(models.mnasnet1_0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
